package com.auditlog.sql.factory;

import cn.hutool.core.lang.Assert;
import com.auditlog.datasource.db.SqlType;
import com.auditlog.sql.factory.impl.*;
import net.sf.jsqlparser.statement.Statement;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

/**
 * Registry that maps each {@link SqlType} to the {@link SqlRunnerFactory} responsible for it.
 *
 * <p>A single shared instance is exposed through {@link #getInstance()}. It is pre-populated with the
 * insert-update, update, delete and insert factories. Further factories can be added with
 * {@link #register(SqlRunnerFactory)}; at most one factory may be registered per {@link SqlType}.
 *
 * <p>{@link #support(Statement)} consults factories in registration order, so when more than one
 * factory accepts the same statement, the one registered first wins.
 *
 * <p>Registration and lookup are backed by concurrent collections.
 *
 * @author Zhiyang.Zhang
 * @date 2022/12/4 17:38
 * @version 1.0
 */
public final class SqlRunnerFactoryRegistry {

    /** Lookup table from SQL type to its factory; also used to reject duplicate registrations. */
    private final Map<SqlType, SqlRunnerFactory> sqlTypeSqlRunnerFactoryMap = new ConcurrentHashMap<>();

    /**
     * Factories in registration order. ConcurrentHashMap iteration order is hash-based and unrelated to
     * insertion order; if SqlType is an enum, it may even differ between JVM runs. This list is kept
     * separately so that {@link #support(Statement)} picks deterministically.
     */
    private final List<SqlRunnerFactory> orderedFactories = new CopyOnWriteArrayList<>();

    private SqlRunnerFactoryRegistry() {
        // Registration order is the priority order used by support(). InsertUpdate is registered before
        // Insert, presumably so it takes precedence for statements both accept -- TODO confirm.
        register(new InsertUpdateSqlRunnerFactory());
        register(new UpdateSqlRunnerFactory());
        register(new DeleteSqlRunnerFactory());
        register(new InsertSqlRunnerFactory());
    }

    /** Lazy-initialisation holder: INSTANCE is created on first access to getInstance(). */
    private static class SingletonHolder {
        private static final SqlRunnerFactoryRegistry INSTANCE = new SqlRunnerFactoryRegistry();
    }

    /**
     * Returns the shared registry instance.
     *
     * @return the singleton registry
     */
    public static SqlRunnerFactoryRegistry getInstance() {
        return SingletonHolder.INSTANCE;
    }

    /**
     * Registers a factory under the SQL type it reports via {@code sqlType()}.
     *
     * @param sqlRunnerFactory the factory to register; must not be null
     * @throws IllegalArgumentException if {@code sqlRunnerFactory} is null, or a factory is already
     *                                  registered for its SQL type
     */
    public void register(SqlRunnerFactory sqlRunnerFactory) {
        Assert.isTrue(sqlRunnerFactory != null, "注册的sql解析器为null");
        SqlType sqlType = sqlRunnerFactory.sqlType();
        // putIfAbsent makes the duplicate check and the insert a single atomic step.
        SqlRunnerFactory old = sqlTypeSqlRunnerFactoryMap.putIfAbsent(sqlType, sqlRunnerFactory);
        // Hutool message templates use "{}" placeholders; "%s" would be printed literally.
        Assert.isTrue(old == null, "已存在{}对应的sql解析器", sqlType);
        // Only the factory that won putIfAbsent reaches this line, so the list holds no duplicates.
        orderedFactories.add(sqlRunnerFactory);
    }

    /**
     * Returns the factory registered for the given SQL type.
     *
     * @param type the SQL type to look up; must not be null (ConcurrentHashMap rejects null keys)
     * @return the registered factory, or {@code null} if none is registered for {@code type}
     */
    public SqlRunnerFactory get(SqlType type) {
        return sqlTypeSqlRunnerFactoryMap.get(type);
    }

    /**
     * Finds the first registered factory (in registration order) whose {@code support} accepts the
     * statement.
     *
     * @param statement the parsed statement to match
     * @return the first matching factory, or a new {@code DefaultSqlRunnerFactory} if none matches
     */
    public SqlRunnerFactory support(Statement statement) {
        return orderedFactories.stream()
                .filter(factory -> factory.support(statement))
                .findFirst()
                // orElseGet: only build the fallback when no registered factory matched.
                .orElseGet(DefaultSqlRunnerFactory::new);
    }
}
